{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7f6607dd",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/examples/query_engine/CustomRetrievers.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5eb0fa8d-51a3-4b32-86c6-0df4baead164",
   "metadata": {},
   "source": [
    "# Retriever Query Engine with Custom Retrievers - Simple Hybrid Search\n",
    "\n",
    "In this tutorial, we show you how to define a very simple version of hybrid search! \n",
    "\n",
    "Combine keyword lookup retrieval with vector retrieval using \"AND\" and \"OR\" conditions."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9393b260-57e9-4134-8793-b3423bc891ca",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e0f9dee",
   "metadata": {},
   "source": [
    "If you're opening this Notebook on colab, you will probably need to install LlamaIndex 🦙."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feffca36",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install llama-index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc89539b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31f72acb",
   "metadata": {},
   "source": [
    "### Download Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f44492d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Will not apply HSTS. The HSTS database must be a regular and non-world-writable file.\n",
      "ERROR: could not open HSTS store at '/home/loganm/.wget-hsts'. HSTS will be disabled.\n",
      "--2023-11-23 12:54:37--  https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 75042 (73K) [text/plain]\n",
      "Saving to: ‘data/paul_graham/paul_graham_essay.txt’\n",
      "\n",
      "data/paul_graham/pa 100%[===================>]  73.28K  --.-KB/s    in 0.04s   \n",
      "\n",
      "2023-11-23 12:54:37 (1.77 MB/s) - ‘data/paul_graham/paul_graham_essay.txt’ saved [75042/75042]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!mkdir -p 'data/paul_graham/'\n",
    "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ec6a93b3-04e1-4c21-95d8-b93873e68fad",
   "metadata": {},
   "source": [
    "### Load Data\n",
    "\n",
    "We first show how to convert a Document into a set of Nodes, and insert into a DocumentStore."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1f2276b-03f3-4dd9-9178-ed274d465c17",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import SimpleDirectoryReader\n",
    "\n",
    "# load documents\n",
    "documents = SimpleDirectoryReader(\"./data/paul_graham\").load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "571a5862-0f6f-42b6-9975-2426f1ee8b8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import Settings\n",
    "\n",
    "nodes = Settings.get_nodes_from_documents(documents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e61b9f42-a8a0-491f-8342-3cc366689c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import StorageContext\n",
    "\n",
    "# initialize storage context (by default it's in-memory)\n",
    "storage_context = StorageContext.from_defaults()\n",
    "storage_context.docstore.add_documents(nodes)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "88fb9077-c4aa-4a52-ac8d-75542ab27501",
   "metadata": {},
   "source": [
    "### Define Vector Index and Keyword Table Index over Same Data\n",
    "\n",
    "We build a vector index and keyword index over the same DocumentStore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2f3bb70-0c1d-42c0-9b66-c73c0056339f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
      "HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
      "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
      "HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
      "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
      "HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n"
     ]
    }
   ],
   "source": [
    "from llama_index.core import SimpleKeywordTableIndex, VectorStoreIndex\n",
    "\n",
    "vector_index = VectorStoreIndex(nodes, storage_context=storage_context)\n",
    "keyword_index = SimpleKeywordTableIndex(nodes, storage_context=storage_context)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6df55ba7-8070-4149-b786-3d3e64b7d64f",
   "metadata": {},
   "source": [
    "### Define Custom Retriever\n",
    "\n",
    "We now define a custom retriever class that can implement basic hybrid search with both keyword lookup and semantic search.\n",
    "\n",
    "- setting \"AND\" means we take the intersection of the two retrieved sets\n",
    "- setting \"OR\" means we take the union"
   ]
  },
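  {
   "cell_type": "markdown",
   "id": "c3a1e7b2-and-or-sketch-md",
   "metadata": {},
   "source": [
    "As a minimal sketch of the two modes, the cell below uses plain Python sets in place of retriever results. The node IDs (`n1` to `n4`) are hypothetical and only illustrate the set logic; they are not taken from the indexes built above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a1e7b2-and-or-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypothetical node IDs standing in for what each retriever might return\n",
    "vector_ids = {\"n1\", \"n2\", \"n3\"}\n",
    "keyword_ids = {\"n2\", \"n3\", \"n4\"}\n",
    "\n",
    "# \"AND\": keep only the nodes that both retrievers returned\n",
    "and_ids = vector_ids & keyword_ids\n",
    "\n",
    "# \"OR\": keep the nodes that either retriever returned\n",
    "or_ids = vector_ids | keyword_ids\n",
    "\n",
    "print(sorted(and_ids))  # ['n2', 'n3']\n",
    "print(sorted(or_ids))  # ['n1', 'n2', 'n3', 'n4']"
   ]
  },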
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b45b406e-c42c-48d2-bd62-228cf62b3b43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import QueryBundle\n",
    "from llama_index.core import QueryBundle\n",
    "\n",
    "# import NodeWithScore\n",
    "from llama_index.core.schema import NodeWithScore\n",
    "\n",
    "# Retrievers\n",
    "from llama_index.core.retrievers import (\n",
    "    BaseRetriever,\n",
    "    VectorIndexRetriever,\n",
    "    KeywordTableSimpleRetriever,\n",
    ")\n",
    "\n",
    "from typing import List"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b17f16c-6059-4855-9417-a00bb41d4c12",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomRetriever(BaseRetriever):\n",
    "    \"\"\"Custom retriever that performs both semantic search and hybrid search.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        vector_retriever: VectorIndexRetriever,\n",
    "        keyword_retriever: KeywordTableSimpleRetriever,\n",
    "        mode: str = \"AND\",\n",
    "    ) -> None:\n",
    "        \"\"\"Init params.\"\"\"\n",
    "\n",
    "        self._vector_retriever = vector_retriever\n",
    "        self._keyword_retriever = keyword_retriever\n",
    "        if mode not in (\"AND\", \"OR\"):\n",
    "            raise ValueError(\"Invalid mode.\")\n",
    "        self._mode = mode\n",
    "        super().__init__()\n",
    "\n",
    "    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:\n",
    "        \"\"\"Retrieve nodes given query.\"\"\"\n",
    "\n",
    "        vector_nodes = self._vector_retriever.retrieve(query_bundle)\n",
    "        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)\n",
    "\n",
    "        vector_ids = {n.node.node_id for n in vector_nodes}\n",
    "        keyword_ids = {n.node.node_id for n in keyword_nodes}\n",
    "\n",
    "        combined_dict = {n.node.node_id: n for n in vector_nodes}\n",
    "        combined_dict.update({n.node.node_id: n for n in keyword_nodes})\n",
    "\n",
    "        if self._mode == \"AND\":\n",
    "            retrieve_ids = vector_ids.intersection(keyword_ids)\n",
    "        else:\n",
    "            retrieve_ids = vector_ids.union(keyword_ids)\n",
    "\n",
    "        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]\n",
    "        return retrieve_nodes"
   ]
  },
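  {
   "cell_type": "markdown",
   "id": "d4b2f8c3-score-sort-sketch-md",
   "metadata": {},
   "source": [
    "The retriever above collects node IDs in a Python set, so the order of the returned nodes is not guaranteed. If a stable order matters, one option is to sort the merged nodes by score. The cell below is a sketch under assumptions: `ScoredNode` and `sort_by_score` are hypothetical stand-ins written for this illustration, not part of LlamaIndex, and a missing score is treated as `0.0`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4b2f8c3-score-sort-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from typing import List, Optional\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ScoredNode:\n",
    "    \"\"\"Hypothetical stand-in for a retrieved node with an optional score.\"\"\"\n",
    "\n",
    "    node_id: str\n",
    "    score: Optional[float] = None\n",
    "\n",
    "\n",
    "def sort_by_score(nodes: List[ScoredNode]) -> List[ScoredNode]:\n",
    "    # treat a missing score as 0.0 so unscored nodes come after positively scored ones\n",
    "    return sorted(nodes, key=lambda n: n.score or 0.0, reverse=True)\n",
    "\n",
    "\n",
    "merged = [ScoredNode(\"n2\", 0.71), ScoredNode(\"n4\"), ScoredNode(\"n3\", 0.83)]\n",
    "print([n.node_id for n in sort_by_score(merged)])  # ['n3', 'n2', 'n4']"
   ]
  },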
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d111dc0c-a5f7-44d1-b9b8-fbdb36683040",
   "metadata": {},
   "source": [
    "### Plugin Retriever into Query Engine\n",
    "\n",
    "Plugin retriever into a query engine, and run some queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d328d7e-4537-4007-852d-e575e4d11110",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import get_response_synthesizer\n",
    "from llama_index.core.query_engine import RetrieverQueryEngine\n",
    "\n",
    "# define custom retriever\n",
    "vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=2)\n",
    "keyword_retriever = KeywordTableSimpleRetriever(index=keyword_index)\n",
    "custom_retriever = CustomRetriever(vector_retriever, keyword_retriever)\n",
    "\n",
    "# define response synthesizer\n",
    "response_synthesizer = get_response_synthesizer()\n",
    "\n",
    "# assemble query engine\n",
    "custom_query_engine = RetrieverQueryEngine(\n",
    "    retriever=custom_retriever,\n",
    "    response_synthesizer=response_synthesizer,\n",
    ")\n",
    "\n",
    "# vector query engine\n",
    "vector_query_engine = RetrieverQueryEngine(\n",
    "    retriever=vector_retriever,\n",
    "    response_synthesizer=response_synthesizer,\n",
    ")\n",
    "# keyword query engine\n",
    "keyword_query_engine = RetrieverQueryEngine(\n",
    "    retriever=keyword_retriever,\n",
    "    response_synthesizer=response_synthesizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cbc3209-cb4e-4a0a-ada8-1d8eff9e657e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
      "HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
      "INFO:llama_index.indices.keyword_table.retrievers:> Starting query: What did the author do during his time at YC?\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Starting query: What did the author do during his time at YC?\n",
      "INFO:llama_index.indices.keyword_table.retrievers:query keywords: ['author', 'yc', 'time']\n",
      "query keywords: ['author', 'yc', 'time']\n",
      "INFO:llama_index.indices.keyword_table.retrievers:> Extracted keywords: ['yc', 'time']\n",
      "> Extracted keywords: ['yc', 'time']\n",
      "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n"
     ]
    }
   ],
   "source": [
    "response = custom_query_engine.query(\n",
    "    \"What did the author do during his time at YC?\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3f74212-42b1-40f0-b9aa-98098ad75a22",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "During his time at YC, the author worked on various projects, including writing essays and working on YC itself. He also wrote all of YC's internal software in Arc. Additionally, he mentioned that he dealt with urgent problems, with about a 60% chance of them being related to Hacker News (HN), and a 40% chance of them being related to everything else combined. The author also mentioned that YC was different from other kinds of work he had done, as the problems of the startups in each batch became their problems, and he worked hard even at the parts of the job he didn't like.\n"
     ]
    }
   ],
   "source": [
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48eb7360-644d-408a-8c43-8ddebbdebb64",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
      "HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
      "INFO:llama_index.indices.keyword_table.retrievers:> Starting query: What did the author do during his time at Yale?\n",
      "> Starting query: What did the author do during his time at Yale?\n",
      "INFO:llama_index.indices.keyword_table.retrievers:query keywords: ['author', 'yale', 'time']\n",
      "query keywords: ['author', 'yale', 'time']\n",
      "INFO:llama_index.indices.keyword_table.retrievers:> Extracted keywords: ['time']\n",
      "> Extracted keywords: ['time']\n"
     ]
    }
   ],
   "source": [
    "# hybrid search can allow us to not retrieve nodes that are irrelevant\n",
    "# Yale is never mentioned in the essay\n",
    "response = custom_query_engine.query(\n",
    "    \"What did the author do during his time at Yale?\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95da0140-19f7-4533-befb-6153c9bda550",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Empty Response\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(str(response))\n",
    "len(response.source_nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6507bf10-2c0c-484b-bea9-fbd60354fcad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
      "HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n",
      "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n",
      "HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n"
     ]
    }
   ],
   "source": [
    "# in contrast, vector search will return an answer\n",
    "response = vector_query_engine.query(\n",
    "    \"What did the author do during his time at Yale?\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1696bac-7978-4e2a-ad1b-84e5b2d5a1d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The context information does not provide any information about the author's time at Yale.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(str(response))\n",
    "len(response.source_nodes)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama-index-4a-wkI5X-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
